import numpy as np
import emcee
import celerite
import matplotlib.pyplot as plt
from celerite import terms
from scipy.optimize import minimize
from multiprocessing import Pool
from time import perf_counter
# Wall-clock reference taken when this module starts running.
# NOTE(review): not read anywhere in the visible code.
start = perf_counter()

def DRW_process(t, tau, SF, m, rng=None):
    """Simulate a damped random walk (DRW) sampled at times ``t``.

    The first value is drawn from N(m, SF/sqrt(2)). Each later value is
    drawn from a normal distribution. Its mean relaxes from the previous
    value towards ``m`` by the factor exp(-dt/tau). Its standard deviation
    is (SF/sqrt(2)) * sqrt(1 - exp(-2*dt/tau)).

    Parameters
    ----------
    t : array_like
        1-D observation times; must be sorted in non-decreasing order.
    tau : float
        Damping timescale, in the same units as ``t``.
    SF : float
        Amplitude parameter; ``SF/sqrt(2)`` is the standard deviation of
        the first draw.
    m : float
        Mean level the process relaxes towards.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Random source providing ``.normal``. When omitted, the global
        ``np.random`` state is used (the original behaviour).

    Returns
    -------
    numpy.ndarray
        One simulated value per entry of ``t`` (empty for empty ``t``).

    Raises
    ------
    ValueError
        If any scaled time step ``diff(t)/tau`` is negative (e.g. unsorted
        ``t``).
    """
    t = np.asarray(t, dtype=float)
    if t.size == 0:
        return np.empty(0)

    r = np.diff(t) / tau
    # Validate before simulating anything. Previously the check ran inside
    # the loop, printed a message and returned None, and the caller then
    # failed later with an unrelated error.
    if np.any(r < 0):
        raise ValueError("time series is not sorted (t must be non-decreasing)")

    normal = np.random.normal if rng is None else rng.normal
    sigma = SF / np.sqrt(2.0)  # previously approximated as SF / 1.414
    decay = np.exp(-r)
    # -expm1(-2r) equals 1 - exp(-2r) but stays accurate for tiny steps.
    step_sd = sigma * np.sqrt(-np.expm1(-2.0 * r))

    out = np.empty(t.size)
    out[0] = normal(m, sigma)
    for i, (d, sd) in enumerate(zip(decay, step_sd), start=1):
        # Conditional mean: previous * exp(-r) + m * (1 - exp(-r)).
        out[i] = normal(out[i - 1] * d + m * (1.0 - d), sd)
    return out

def DRW_fit(t, s, err, mean, nwalkers=16, nburn=125, nsteps=500):
    """Fit a celerite ``RealTerm`` GP to a light curve and sample it with emcee.

    The kernel parameters and a fitted constant mean are first optimised by
    maximum likelihood, restricted to the kernel bounds. The optimum then
    seeds an emcee ensemble. After ``nburn`` burn-in steps the sampler is
    reset, and ``nsteps`` production steps are kept.

    Parameters
    ----------
    t, s, err : array_like
        Observation times, observed values and their 1-sigma uncertainties.
    mean : float
        Initial value of the fitted constant mean.
    nwalkers : int, optional
        Number of emcee walkers (default 16).
    nburn : int, optional
        Number of burn-in steps, discarded (default 125).
    nsteps : int, optional
        Number of production steps kept (default 500).

    Returns
    -------
    ls_chain : numpy.ndarray
        Flattened samples of ``log_a / 2``.
    lt_chain : numpy.ndarray
        Flattened samples of ``-log_c`` (the log of 1/c).
    """
    def mle(params, y, gp):
        # Cost function for scipy: the negative log-likelihood.
        gp.set_parameter_vector(params)
        return -gp.log_likelihood(y)

    def log_probability(params):
        gp.set_parameter_vector(params)
        lp = gp.log_prior()
        # Outside the kernel bounds: reject before evaluating the likelihood.
        if not np.isfinite(lp):
            return -np.inf
        pdict = gp.get_parameter_dict()
        loga = pdict["kernel:log_a"]
        logc = pdict["kernel:log_c"]
        # Additional prior term (-0.5*log_a + 0.5*log_c) on top of the bounds.
        return gp.log_likelihood(s) + lp - 0.5 * loga + 0.5 * logc

    # Bounds: exp(log_a/2) in [0.02, 0.7] and 1/c in [1, 5000].
    bounds = dict(log_a=(2 * np.log(0.02), 2 * np.log(0.7)),
                  log_c=(-np.log(5000), 0))
    # NOTE(review): the log_a bounds are written as 2*log(x), but the initial
    # log_a is log(0.1414) rather than 2*log(0.1414). Confirm which starting
    # amplitude was intended.
    kernel = terms.RealTerm(log_a=np.log(0.1414), log_c=-np.log(400),
                            bounds=bounds)
    gp = celerite.GP(kernel, mean=mean, fit_mean=True)
    gp.compute(t, err)

    # Pass the parameter bounds to L-BFGS-B. Without them the optimum can
    # land where gp.log_prior() is -inf, and the walkers seeded around it
    # would then start with log-probability -inf.
    soln = minimize(mle, gp.get_parameter_vector(), method="L-BFGS-B",
                    bounds=gp.get_parameter_bounds(), args=(s, gp))
    gp.set_parameter_vector(soln.x)

    # MCMC: a tight Gaussian ball of walkers around the optimum.
    ndim = len(soln.x)
    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_probability)
    p0 = soln.x + 1e-4 * np.random.randn(nwalkers, ndim)
    state = sampler.run_mcmc(p0, nburn)  # burn-in
    sampler.reset()
    sampler.run_mcmc(state, nsteps)  # production
    chain = sampler.get_chain(flat=True)

    # Column 0 is kernel:log_a and column 1 is kernel:log_c.
    lt_chain = -chain[:, 1]
    ls_chain = chain[:, 0] / 2
    return ls_chain, lt_chain
# Simulation-recovery experiment: simulate a DRW on random sorted epochs in
# [0, 2922), add Gaussian noise, fit it, then compare the scatter of
# averaged groups of the -log_c chain with the scatter of the whole chain.
tau = 300
epochs = 60
t = np.sort(np.random.uniform(0, 2922, epochs))
y = DRW_process(t, tau, 0.2, 0)
e = np.full(len(t), 0.03)  # constant per-epoch measurement uncertainty
s = np.random.normal(y, e)
ls, lt = DRW_fit(t, s, e, np.mean(s))

# Rows of 50 samples. The row count follows from the chain length
# (160 rows for DRW_fit's default 16 walkers x 500 steps) instead of
# being hard-coded.
a = lt.reshape(-1, 50)
# NOTE(review): averaging over axis 0 combines samples spaced 50 apart in
# the flattened chain. Classic batch means would average contiguous runs
# (axis=1). Confirm which is intended.
em = np.std(np.mean(a, axis=0))
ee = np.std(lt)
print(em, ee, (em / ee) ** 2)

# Set to True to inspect the chains visually.
SHOW_PLOTS = False
if SHOW_PLOTS:
    plt.plot(lt)
    plt.show()
    plt.plot(ls)
    plt.show()
